use vir::messages::AstId;

use rustc_hir::HirId;
use rustc_span::SpanData;

use vir::ast::{Datatype, Dt, Fun, Function, Krate, Mode, Path, Pattern};
use vir::modes::ErasureModes;

use crate::verus_items::{DummyCaptureItem, VerusItem, VerusItems};
use rustc_hir::def_id::LocalDefId;
use rustc_mir_build_verus::verus::{
    BodyErasure, CallErasure, ExpectSpec, ExpectSpecArgs, NodeErase, VarErasure, VerusErasureCtxt,
    set_verus_aware_def_ids, set_verus_erasure_ctxt,
};
use rustc_span::Span;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use vir::ast::VirErr;

/// Built-in operators and Verus helper calls that get special treatment when a call is
/// resolved (they appear wrapped in `ResolvedCall::CompilableOperator`).
///
/// How each variant is erased is decided in `resolved_call_to_call_erase`; for example
/// `GhostExec` maps to `CallErasure::EraseTree`, `IntIntrinsic` erases only the call node,
/// and `TrackedGet` maps to `CallErasure::keep_all()`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum CompilableOperator {
    IntIntrinsic,
    Implies,
    RcNew,
    ArcNew,
    BoxNew,
    /// NOTE(review): `is_method` presumably records whether the clone was written with
    /// method-call syntax; it does not affect erasure here — confirm at the construction site.
    SmartPtrClone { is_method: bool },
    GhostExec,
    TrackedNew,
    TrackedExec,
    TrackedExecBorrow,
    TrackedGet,
    TrackedBorrow,
    TrackedBorrowMut,
    UseTypeInvariant,
    ClosureToFnProof(Mode),
    Resolve,
}

/// Information about each call in the AST (each ExprKind::Call).
///
/// These values are recorded during the first run and are translated into the
/// `CallErasure` values expected by the `rustc_mir_build_verus` fork by
/// `resolved_call_to_call_erase`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ResolvedCall {
    /// The call is to a spec or proof function, and should be erased
    Spec,
    /// The call is to a spec or proof function, but may have proof-mode arguments
    SpecAllowProofArgs,
    /// The call is to an operator like == or + that should be compiled.
    CompilableOperator(CompilableOperator),
    /// The call is to a function, and we record the name of the function here
    /// (both unresolved and resolved), as well as an in_ghost flag.
    /// This is replaced by CallModes as soon as the modes are available.
    Call(Fun, Fun, bool),
    /// Path and variant of datatype constructor
    Ctor(Path, vir::ast::Ident),
    /// Path and variant of datatype constructor. Used for ExprKind::Struct nodes.
    ///
    /// The remaining fields are the field names supplied in the struct expression (one
    /// erasure argument is produced per name, in this order) and a `has_tail` flag; when it
    /// is set, one extra final argument is added for the tail. NOTE(review): the tail is
    /// presumably a `..base` functional-update expression — confirm at the construction site.
    BracesCtor(Path, vir::ast::Ident, Arc<Vec<vir::ast::Ident>>, bool),
    /// The call is to a dynamically computed function, and is exec
    NonStaticExec,
    /// The call is to a dynamically computed function, and is proof.
    /// Carries one `Mode` per argument, in argument order.
    NonStaticProof(Arc<Vec<Mode>>),
}

/// Data recorded during the first run (HIR -> VIR translation and mode checking) that is
/// later turned into a `VerusErasureCtxt` by `setup_verus_ctxt_for_thir_erasure`.
#[derive(Clone)]
pub struct ErasureHints {
    /// Copy of the entire VIR crate that was created in the first run's HIR -> VIR transformation
    pub vir_crate: Krate,
    /// Connect expression and pattern HirId to corresponding vir AstId
    /// (one AstId may be paired with several HirIds)
    pub hir_vir_ids: Vec<(HirId, AstId)>,
    /// Details of each call in the first run's HIR
    pub resolved_calls: Vec<(HirId, SpanData, ResolvedCall)>,
    /// Details of some patterns in first run's HIR
    pub resolved_pats: Vec<(SpanData, Pattern)>,
    /// Results of mode (spec/proof/exec) inference from first run's VIR
    pub erasure_modes: ErasureModes,
    /// Modes specified directly during rust_to_vir
    pub direct_var_modes: Vec<(HirId, Mode)>,
    /// List of #[verifier(external)] functions.  (These don't appear in vir_crate,
    /// so we need to record them separately here.)
    pub external_functions: Vec<Fun>,
    /// List of function spans ignored by the verifier. These should not be erased
    pub ignored_functions: Vec<(rustc_span::def_id::DefId, SpanData)>,
    /// Per-body erasure information keyed by the body owner's `LocalDefId`; copied
    /// verbatim into `VerusErasureCtxt::bodies`.
    pub(crate) bodies: Vec<(LocalDefId, BodyErasure)>,
}

/// Decide how the erasure pass treats a variable of the given mode:
/// only spec variables are erased; proof and exec variables survive.
fn mode_to_var_erase(mode: Mode) -> VarErasure {
    let erased = match mode {
        Mode::Spec => true,
        Mode::Proof | Mode::Exec => false,
    };
    if erased { VarErasure::Erase } else { VarErasure::Keep }
}

/// Translate ResolvedCall (generated by the rust_verify HIR traversal) to CallErasure,
/// which is what the rustc_mir_build_verus fork expects.
/// REVIEW: it might simpler to skip the ResolvedCall call entirely and have the original
/// traversal generate CallErasure values.
fn resolved_call_to_call_erase(
    _span: Span,
    functions: &HashMap<Fun, Function>,
    datatypes: &HashMap<Path, Datatype>,
    resolved_call: &ResolvedCall,
) -> Result<CallErasure, VirErr> {
    Ok(match resolved_call {
        ResolvedCall::Spec => CallErasure::EraseTree,
        ResolvedCall::SpecAllowProofArgs => {
            CallErasure::Call(NodeErase::Erase, ExpectSpecArgs::AllPropagate)
        }
        ResolvedCall::Call(ufun, rfun, in_ghost) => {
            // Note: in principle, the unresolved function ufun should always be present,
            // but we currently allow external declarations of resolved trait functions
            // without a corresponding external trait declaration.
            let Some(f) = functions.get(ufun).or_else(|| functions.get(rfun)) else {
                dbg!(ufun, rfun);
                panic!("internal Verus error: could not find mode declarations for function")
            };
            if *in_ghost && f.x.mode == Mode::Exec {
                // This must be an autospec, so change exec -> spec
                CallErasure::Call(NodeErase::Erase, ExpectSpecArgs::AllYes)
            } else if f.x.mode == Mode::Spec {
                CallErasure::Call(NodeErase::Erase, ExpectSpecArgs::AllYes)
            } else {
                let args =
                    f.x.params
                        .iter()
                        .map(|p| match p.x.mode {
                            Mode::Spec => ExpectSpec::Yes,
                            Mode::Proof | Mode::Exec => ExpectSpec::No,
                        })
                        .collect::<Vec<_>>();
                CallErasure::Call(NodeErase::Keep, ExpectSpecArgs::PerArg(Arc::new(args)))
            }
        }
        ResolvedCall::Ctor(path, variant_name) => {
            let datatype = &datatypes[path];
            match &datatype.x.mode {
                Mode::Spec => {
                    CallErasure::Call(NodeErase::WhenExpectingSpec, ExpectSpecArgs::AllYes)
                }
                Mode::Proof | Mode::Exec => {
                    let variant = datatype.x.get_variant(variant_name);
                    let args = variant
                        .fields
                        .iter()
                        .map(|field| {
                            let (_, field_mode, _) = &field.a;
                            match field_mode {
                                Mode::Spec => ExpectSpec::Yes,
                                Mode::Proof | Mode::Exec => ExpectSpec::Propagate,
                            }
                        })
                        .collect::<Vec<_>>();
                    CallErasure::Call(
                        NodeErase::WhenExpectingSpec,
                        ExpectSpecArgs::PerArg(Arc::new(args)),
                    )
                }
            }
        }
        ResolvedCall::BracesCtor(path, variant_name, fields, has_tail) => {
            let datatype = &datatypes[path];
            match &datatype.x.mode {
                Mode::Spec => {
                    CallErasure::Call(NodeErase::WhenExpectingSpec, ExpectSpecArgs::AllYes)
                }
                Mode::Proof | Mode::Exec => {
                    let variant = datatype.x.get_variant(variant_name);
                    let mut args = fields
                        .iter()
                        .map(|field_name| {
                            let field = vir::ast_util::get_field(&variant.fields, field_name);
                            let (_, field_mode, _) = &field.a;
                            match field_mode {
                                Mode::Spec => ExpectSpec::Yes,
                                Mode::Proof | Mode::Exec => ExpectSpec::Propagate,
                            }
                        })
                        .collect::<Vec<_>>();
                    if *has_tail {
                        args.push(ExpectSpec::Propagate);
                    }
                    CallErasure::Call(
                        NodeErase::WhenExpectingSpec,
                        ExpectSpecArgs::PerArg(Arc::new(args)),
                    )
                }
            }
        }
        ResolvedCall::NonStaticExec => CallErasure::keep_all(),
        ResolvedCall::NonStaticProof(modes) => {
            let args = modes
                .iter()
                .map(|mode| match mode {
                    Mode::Spec => ExpectSpec::Yes,
                    Mode::Proof | Mode::Exec => ExpectSpec::No,
                })
                .collect::<Vec<_>>();
            CallErasure::Call(NodeErase::Keep, ExpectSpecArgs::PerArg(Arc::new(args)))
        }
        ResolvedCall::CompilableOperator(co) => match co {
            CompilableOperator::IntIntrinsic => {
                CallErasure::Call(NodeErase::Erase, ExpectSpecArgs::AllPropagate)
            }

            CompilableOperator::GhostExec => CallErasure::EraseTree,

            CompilableOperator::Implies
            | CompilableOperator::RcNew
            | CompilableOperator::ArcNew
            | CompilableOperator::BoxNew
            | CompilableOperator::SmartPtrClone { .. }
            | CompilableOperator::TrackedNew
            | CompilableOperator::TrackedExec => {
                CallErasure::Call(NodeErase::WhenExpectingSpec, ExpectSpecArgs::AllPropagate)
            }

            CompilableOperator::ClosureToFnProof(_)
            | CompilableOperator::TrackedExecBorrow
            | CompilableOperator::TrackedGet
            | CompilableOperator::TrackedBorrow
            | CompilableOperator::TrackedBorrowMut
            | CompilableOperator::Resolve
            | CompilableOperator::UseTypeInvariant => CallErasure::keep_all(),
        },
    })
}

/// Build the `VerusErasureCtxt` consumed by the `rustc_mir_build_verus` fork and install it
/// with `set_verus_erasure_ctxt`.
///
/// Modes in `erasure_hints.erasure_modes` are keyed by VIR `AstId`. This function maps them
/// back onto HIR ids via `erasure_hints.hir_vir_ids`, and it translates every resolved call
/// into a `CallErasure`.
///
/// # Errors
/// Propagates any error returned by `resolved_call_to_call_erase`.
///
/// # Panics
/// Panics on internal inconsistencies:
/// - a var/condition mode whose `AstId` has no HIR counterpart;
/// - duplicate function or datatype names in the VIR crate;
/// - conflicting condition modes for the same HIR node;
/// - missing `ErasedGhostValue` / `DummyCapture` Verus items.
pub(crate) fn setup_verus_ctxt_for_thir_erasure(
    verus_items: &VerusItems,
    erasure_hints: &ErasureHints,
) -> Result<(), VirErr> {
    // A single VIR AstId may correspond to several HIR nodes.
    let mut id_to_hir: HashMap<AstId, Vec<HirId>> = HashMap::new();
    for (hir_id, vir_id) in &erasure_hints.hir_vir_ids {
        id_to_hir.entry(*vir_id).or_default().push(*hir_id);
    }

    let mut vars = HashMap::<HirId, VarErasure>::new();
    for (span, mode) in erasure_hints.erasure_modes.var_modes.iter() {
        // Entries whose raw span cannot be decoded are skipped.
        if crate::spans::from_raw_span(&span.raw_span).is_none() {
            continue;
        }
        let Some(hir_ids) = id_to_hir.get(&span.id) else {
            panic!("missing id_to_hir for var mode {:?}: {:?}", mode, span);
        };
        for hir_id in hir_ids {
            vars.insert(*hir_id, mode_to_var_erase(*mode));
        }
    }

    let mut functions = HashMap::<Fun, Function>::new();
    for f in &erasure_hints.vir_crate.functions {
        if functions.insert(f.x.name.clone(), f.clone()).is_some() {
            panic!("duplicate function in VIR crate: {:?}", &f.x.name);
        }
    }

    let mut datatypes = HashMap::<Path, Datatype>::new();
    for d in &erasure_hints.vir_crate.datatypes {
        if let Dt::Path(path) = &d.x.name {
            if datatypes.insert(path.clone(), d.clone()).is_some() {
                panic!("duplicate datatype in VIR crate: {:?}", &path);
            }
        }
    }

    let mut calls = HashMap::<HirId, CallErasure>::new();
    for (hir_id, span_data, resolved_call) in &erasure_hints.resolved_calls {
        let span = span_data.span();
        calls.insert(
            *hir_id,
            resolved_call_to_call_erase(span, &functions, &datatypes, resolved_call)?,
        );
    }

    let bodies: HashMap<LocalDefId, BodyErasure> =
        erasure_hints.bodies.iter().copied().collect();

    let mut condition_spec = HashMap::<HirId, bool>::new();
    for (span, mode) in &erasure_hints.erasure_modes.condition_modes {
        if crate::spans::from_raw_span(&span.raw_span).is_none() {
            continue;
        }
        let spec = matches!(mode, Mode::Spec);
        let Some(hir_ids) = id_to_hir.get(&span.id) else {
            panic!("missing id_to_hir: {:?} (id {:?})", span, span.id);
        };
        for hir_id in hir_ids {
            // A HIR node may receive several condition-mode entries; they must all agree.
            let recorded = *condition_spec.entry(*hir_id).or_insert(spec);
            if recorded != spec {
                panic!("inconsistent condition_modes: {:?}", span);
            }
        }
    }

    let erased_ghost_value_fn_def_id = *verus_items
        .name_to_id
        .get(&VerusItem::ErasedGhostValue)
        .expect("internal Verus error: missing ErasedGhostValue verus item");
    let dummy_capture_struct_def_id = *verus_items
        .name_to_id
        .get(&VerusItem::DummyCapture(DummyCaptureItem::Struct))
        .expect("internal Verus error: missing DummyCapture struct verus item");

    let verus_erasure_ctxt = VerusErasureCtxt {
        vars,
        calls,
        bodies,

        condition_spec,

        erased_ghost_value_fn_def_id,
        dummy_capture_struct_def_id,
    };
    set_verus_erasure_ctxt(Arc::new(verus_erasure_ctxt));

    Ok(())
}

pub(crate) fn setup_verus_aware_ids(crate_items: &crate::external::CrateItems) {
    // Requirements:
    //  - If a function requires Verus-erasure, then it MUST be in the set
    //  - If a function has special properties (e.g., being const), that may cause Rust
    //    to run mir_borrowck on it before Verus mode-checking, then it MUST NOT be in the set.
    // For anything else: it doesn't matter.
    //
    // Since most consts are marked external, we can just use the VerusAware set for this.
    // We carve out exceptions for some special directives.

    let mut s = HashSet::<LocalDefId>::new();
    for item in crate_items.items.iter() {
        match &item.verif {
            crate::external::VerifOrExternal::VerusAware { const_directive, .. } => {
                if !*const_directive {
                    s.insert(item.id.owner_id().def_id);
                }
            }
            crate::external::VerifOrExternal::External { .. } => {}
        }
    }
    set_verus_aware_def_ids(Arc::new(s));
}
